{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xHF95Kr4CzGq"
   },
   "source": [
    "# 🤗 Welcome to AdalFlow!\n",
    "## The PyTorch library to auto-optimize any LLM task pipelines\n",
    "\n",
    "Thanks for trying us out, we're here to provide you with the best LLM application development experience you can dream of 😊 any questions or concerns you may have, [come talk to us on discord,](https://discord.gg/ezzszrRZvT) we're always here to help! ⭐ <i>Star us on <a href=\"https://github.com/SylphAI-Inc/AdalFlow\">Github</a> </i> ⭐\n",
    "\n",
    "\n",
    "# Quick Links\n",
    "\n",
    "Github repo: https://github.com/SylphAI-Inc/AdalFlow\n",
    "\n",
    "Full Tutorials: https://adalflow.sylph.ai/index.html#.\n",
    "\n",
    "Deep dive on each API: check out the [developer notes](https://adalflow.sylph.ai/tutorials/index.html).\n",
    "\n",
    "Common use cases along with the auto-optimization:  check out [Use cases](https://adalflow.sylph.ai/use_cases/index.html).\n",
    "\n",
    "## 📖 Outline\n",
    "\n",
    "In this tutorial, we will cover the auto-optimization of a standard RAG:\n",
    "\n",
    "- Introducing HotPotQA dataset and HotPotQAData class.\n",
    "\n",
    "- Convert Dspy’s Retriever to AdalFlow’s Retriever to easy comparison.\n",
    "\n",
    "- Build the standard RAG with Retriever and Generator components.\n",
    "\n",
    "- Learn how to connect the output-input between components to enable auto-text-grad optimization."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Kof5M6DRaKhh"
   },
   "source": [
    "\n",
    "# Installation\n",
    "\n",
    "1. Use `pip` to install the `adalflow` Python package. We will need `openai`, `groq` from the extra packages.\n",
    "\n",
    "  ```bash\n",
    "  pip install adalflow[openai,groq]\n",
    "  ```\n",
    "2. Setup  `openai` and `groq` API key in the environment variables\n",
    "\n",
    "You can choose to use different client. You can import the model client you prefer. We support `Anthropic`, `Cohere`, `Google`, `GROQ`, `OpenAI`, `Transformer` and more in development. We will use OpenAI here as an example.Please refer to our [full installation guide](https://adalflow.sylph.ai/get_started/installation.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "tAp3eDjOCma1"
   },
   "outputs": [],
   "source": [
    "from IPython.display import clear_output\n",
    "\n",
    "!pip install -U adalflow[openai] # also install the package for the model client you'll use\n",
    "!pip install dspy\n",
    "!pip install datasets\n",
    "clear_output()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip uninstall httpx anyio -y\n",
    "!pip install \"anyio>=3.1.0,<4.0\"\n",
    "!pip install httpx==0.24.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KapUyHMM07pJ"
   },
   "source": [
    "## Set Environment Variables\n",
    "\n",
    "Run the following code and pass your api key.\n",
    "\n",
    "Note: for normal `.py` projects, follow our [official installation guide](https://lightrag.sylph.ai/get_started/installation.html).\n",
    "\n",
    "*Go to [OpenAI](https://platform.openai.com/docs/introduction) to get API keys if you don't already have.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ONfzF9Puzdd_",
    "outputId": "5fc0cd30-9ae7-443a-c06c-31e9edeafd69"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Please enter your OpenAI API key: ··········\n",
      "API keys have been set.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from getpass import getpass\n",
    "\n",
    "# Prompt user to enter their API keys securely\n",
    "openai_api_key = getpass(\"Please enter your OpenAI API key: \")\n",
    "\n",
    "\n",
    "# Set environment variables\n",
    "os.environ[\"OPENAI_API_KEY\"] = openai_api_key\n",
    "\n",
    "print(\"API keys have been set.\")"
   ]
  },
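  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell above prompts for the key interactively, which suits a notebook. In a plain `.py` script the key usually comes from the shell environment instead. The sketch below shows one way to fail early when it is missing; `require_env` is a hypothetical helper written for this illustration and is not part of AdalFlow.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def require_env(name: str) -> str:\n",
    "    \"\"\"Return the value of an environment variable or raise a clear error.\"\"\"\n",
    "    value = os.environ.get(name, \"\").strip()\n",
    "    if not value:\n",
    "        raise RuntimeError(f\"Environment variable {name} is not set\")\n",
    "    return value\n",
    "\n",
    "\n",
    "# Example (hypothetical usage): call once at startup, before building a model client.\n",
    "# openai_api_key = require_env(\"OPENAI_API_KEY\")\n",
    "```\n",
    "\n",
    "Failing at startup gives a clearer error than a rejected request deep inside the first model call."
   ]
  },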
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "id": "aE3I05BqOmd7"
   },
   "outputs": [],
   "source": [
    "import dspy\n",
    "import re\n",
    "from typing import List, Union, Optional, Dict, Callable, Any, Tuple\n",
    "from dataclasses import dataclass, field\n",
    "import adalflow as adal\n",
    "from adalflow.optim.parameter import Parameter, ParameterType\n",
    "from adalflow.datasets.hotpot_qa import HotPotQA, HotPotQAData\n",
    "from adalflow.datasets.types import Example\n",
    "from adalflow.core.types import RetrieverOutput\n",
    "from adalflow.core import Component, Generator\n",
    "from adalflow.core.retriever import Retriever\n",
    "from adalflow.core.component import fun_to_component\n",
    "from adalflow.components.model_client.openai_client import OpenAIClient"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cqUUoua9fUxQ"
   },
   "outputs": [],
   "source": [
    "gpt_4o_model = {\n",
    "    \"model_client\": OpenAIClient(),\n",
    "    \"model_kwargs\": {\n",
    "        \"model\": \"gpt-4o-mini\",\n",
    "        \"max_tokens\": 2000,\n",
    "    },\n",
    "}\n",
    "\n",
    "gpt_3_model = {\n",
    "    \"model_client\": OpenAIClient(),\n",
    "    \"model_kwargs\": {\n",
    "        \"model\": \"gpt-3.5-turbo\",\n",
    "        \"max_tokens\": 2000,\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0irHeHUkOmL8",
    "outputId": "61f778a2-9ec1-4fda-daa2-bcc7f31baa78"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "HotPotQAData(id='5a8b57f25542995d1e6f1371', question='Were Scott Derrickson and Ed Wood of the same nationality?', answer='yes', gold_titles=\"{'Scott Derrickson', 'Ed Wood'}\") <class 'adalflow.datasets.types.HotPotQAData'>\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HotPotQAData(id='5a8b57f25542995d1e6f1371', question='Were Scott Derrickson and Ed Wood of the same nationality?', answer='yes', gold_titles=\"{'Scott Derrickson', 'Ed Wood'}\")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def load_datasets():\n",
    "\n",
    "    trainset = HotPotQA(split=\"train\", size=20)\n",
    "    valset = HotPotQA(split=\"val\", size=50)\n",
    "    testset = HotPotQA(split=\"test\", size=50)\n",
    "    print(f\"trainset, valset: {len(trainset)}, {len(valset)}, example: {trainset[0]}\")\n",
    "    return trainset, valset, testset\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class AnswerData(adal.DataClass):\n",
    "    reasoning: str = field(\n",
    "        metadata={\"desc\": \"The reasoning to produce the answer\"},\n",
    "    )\n",
    "    answer: str = field(\n",
    "        metadata={\"desc\": \"The answer you produced\"},\n",
    "    )\n",
    "\n",
    "    __output_fields__ = [\"reasoning\", \"answer\"]\n",
    "\n",
    "\n",
    "dataset = HotPotQA(split=\"train\", size=20)\n",
    "print(dataset[0], type(dataset[0]))\n",
    "\n",
    "HotPotQAData(\n",
    "    id=\"5a8b57f25542995d1e6f1371\",\n",
    "    question=\"Were Scott Derrickson and Ed Wood of the same nationality?\",\n",
    "    answer=\"yes\",\n",
    "    gold_titles=\"{'Scott Derrickson', 'Ed Wood'}\",\n",
    ")"
   ]
  },
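  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the sample printed above, `gold_titles` is a string that looks like a Python set literal rather than an actual set. If you need the titles as a collection, one option is `ast.literal_eval` from the standard library. This is a sketch that assumes every `gold_titles` value has the same shape as the one shown; `parse_gold_titles` is a hypothetical helper, not part of AdalFlow.\n",
    "\n",
    "```python\n",
    "import ast\n",
    "from typing import Set\n",
    "\n",
    "\n",
    "def parse_gold_titles(raw: str) -> Set[str]:\n",
    "    \"\"\"Parse a string such as \"{'A', 'B'}\" into a set of titles.\"\"\"\n",
    "    parsed = ast.literal_eval(raw)  # evaluates literals only, never arbitrary code\n",
    "    return {str(title) for title in parsed}\n",
    "\n",
    "\n",
    "titles = parse_gold_titles(\"{'Scott Derrickson', 'Ed Wood'}\")\n",
    "print(sorted(titles))  # ['Ed Wood', 'Scott Derrickson']\n",
    "```"
   ]
  },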
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "id": "ZZIEtZYHNVjo"
   },
   "outputs": [],
   "source": [
    "class DspyRetriever(adal.Retriever):\n",
    "    def __init__(self, top_k: int = 3):\n",
    "        super().__init__()\n",
    "        self.top_k = top_k\n",
    "        self.dspy_retriever = dspy.Retrieve(k=top_k)\n",
    "\n",
    "    def call(\n",
    "        self, input: str, top_k: Optional[int] = None\n",
    "    ) -> List[adal.RetrieverOutput]:\n",
    "\n",
    "        k = top_k or self.top_k\n",
    "\n",
    "        output = self.dspy_retriever(query_or_queries=input, k=k)\n",
    "        final_output: List[RetrieverOutput] = []\n",
    "        documents = output.passages\n",
    "\n",
    "        final_output.append(\n",
    "            RetrieverOutput(\n",
    "                query=input,\n",
    "                documents=documents,\n",
    "                doc_indices=[],\n",
    "            )\n",
    "        )\n",
    "        return final_output\n",
    "\n",
    "\n",
    "def test_retriever():\n",
    "    question = \"How many storeys are in the castle that David Gregory inherited?\"\n",
    "    retriever = DspyRetriever(top_k=3)\n",
    "    retriever_out = retriever(input=question)\n",
    "    print(f\"retriever_out: {retriever_out}\")\n",
    "\n",
    "\n",
    "def call(\n",
    "    self, question: str, id: Optional[str] = None\n",
    ") -> Union[adal.GeneratorOutput, adal.Parameter]:\n",
    "    prompt_kwargs = self._prepare_input(question)\n",
    "    output = self.llm(prompt_kwargs=prompt_kwargs, id=id)\n",
    "    return output\n",
    "\n",
    "\n",
    "def call(self, question: str, id: str = None) -> adal.GeneratorOutput:\n",
    "    if self.training:\n",
    "        raise ValueError(\"This component is not supposed to be called in training mode\")\n",
    "\n",
    "    retriever_out = self.retriever.call(input=question)\n",
    "\n",
    "    successor_map_fn = lambda x: (  # noqa E731\n",
    "        \"\\n\\n\".join(x[0].documents) if x and x[0] and x[0].documents else \"\"\n",
    "    )\n",
    "    retrieved_context = successor_map_fn(retriever_out)\n",
    "\n",
    "    prompt_kwargs = {\n",
    "        \"context\": retrieved_context,\n",
    "        \"question\": question,\n",
    "    }\n",
    "\n",
    "    output = self.llm.call(\n",
    "        prompt_kwargs=prompt_kwargs,\n",
    "        id=id,\n",
    "    )\n",
    "    return output\n",
    "\n",
    "\n",
    "def forward(self, question: str, id: str = None) -> adal.Parameter:\n",
    "    if not self.training:\n",
    "        raise ValueError(\"This component is not supposed to be called in eval mode\")\n",
    "    retriever_out = self.retriever.forward(input=question)\n",
    "    successor_map_fn = lambda x: (  # noqa E731\n",
    "        \"\\n\\n\".join(x.data[0].documents)\n",
    "        if x.data and x.data[0] and x.data[0].documents\n",
    "        else \"\"\n",
    "    )\n",
    "    retriever_out.add_successor_map_fn(successor=self.llm, map_fn=successor_map_fn)\n",
    "    generator_out = self.llm.forward(\n",
    "        prompt_kwargs={\"question\": question, \"context\": retriever_out}, id=id\n",
    "    )\n",
    "    return generator_out\n",
    "\n",
    "\n",
    "def bicall(\n",
    "    self, question: str, id: str = None\n",
    ") -> Union[adal.GeneratorOutput, adal.Parameter]:\n",
    "    \"\"\"You can also combine both the forward and call in the same function.\n",
    "    Supports both training and eval mode by using __call__ for GradComponents\n",
    "    like Retriever and Generator\n",
    "    \"\"\"\n",
    "    retriever_out = self.retriever(input=question)\n",
    "    if isinstance(retriever_out, adal.Parameter):\n",
    "        successor_map_fn = lambda x: (  # noqa E731\n",
    "            \"\\n\\n\".join(x.data[0].documents)\n",
    "            if x.data and x.data[0] and x.data[0].documents\n",
    "            else \"\"\n",
    "        )\n",
    "        retriever_out.add_successor_map_fn(successor=self.llm, map_fn=successor_map_fn)\n",
    "    else:\n",
    "        successor_map_fn = lambda x: (  # noqa E731\n",
    "            \"\\n\\n\".join(x[0].documents) if x and x[0] and x[0].documents else \"\"\n",
    "        )\n",
    "        retrieved_context = successor_map_fn(retriever_out)\n",
    "    prompt_kwargs = {\n",
    "        \"context\": retrieved_context,\n",
    "        \"question\": question,\n",
    "    }\n",
    "    output = self.llm(prompt_kwargs=prompt_kwargs, id=id)\n",
    "    return output\n",
    "\n",
    "\n",
    "task_desc_str = r\"\"\"Answer questions with short factoid answers.\n",
    "\n",
    "You will receive context(may contain relevant facts) and a question.\n",
    "Think step by step.\"\"\"\n",
    "\n",
    "\n",
    "class VanillaRAG(adal.GradComponent):\n",
    "    def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None):\n",
    "        super().__init__()\n",
    "\n",
    "        self.passages_per_hop = passages_per_hop\n",
    "\n",
    "        self.retriever = DspyRetriever(top_k=passages_per_hop)\n",
    "        self.llm_parser = adal.DataClassParser(\n",
    "            data_class=AnswerData, return_data_class=True, format_type=\"json\"\n",
    "        )\n",
    "        self.llm = Generator(\n",
    "            model_client=model_client,\n",
    "            model_kwargs=model_kwargs,\n",
    "            prompt_kwargs={\n",
    "                \"task_desc_str\": adal.Parameter(\n",
    "                    data=task_desc_str,\n",
    "                    role_desc=\"Task description for the language model\",\n",
    "                    param_type=adal.ParameterType.PROMPT,\n",
    "                ),\n",
    "                \"few_shot_demos\": adal.Parameter(\n",
    "                    data=None,\n",
    "                    requires_opt=True,\n",
    "                    role_desc=\"To provide few shot demos to the language model\",\n",
    "                    param_type=adal.ParameterType.DEMOS,\n",
    "                ),\n",
    "                \"output_format_str\": self.llm_parser.get_output_format_str(),\n",
    "            },\n",
    "            template=answer_template,\n",
    "            output_processors=self.llm_parser,\n",
    "            use_cache=True,\n",
    "        )\n",
    "\n",
    "\n",
    "class VallinaRAGAdal(adal.AdalComponent):\n",
    "    def __init__(\n",
    "        self,\n",
    "        model_client: adal.ModelClient,\n",
    "        model_kwargs: Dict,\n",
    "        backward_engine_model_config: Dict | None = None,\n",
    "        teacher_model_config: Dict | None = None,\n",
    "        text_optimizer_model_config: Dict | None = None,\n",
    "    ):\n",
    "        task = VanillaRAG(\n",
    "            model_client=model_client,\n",
    "            model_kwargs=model_kwargs,\n",
    "            passages_per_hop=3,\n",
    "        )\n",
    "        eval_fn = AnswerMatchAcc(type=\"fuzzy_match\").compute_single_item\n",
    "        loss_fn = adal.EvalFnToTextLoss(\n",
    "            eval_fn=eval_fn, eval_fn_desc=\"fuzzy_match: 1 if str(y) in str(y_gt) else 0\"\n",
    "        )\n",
    "        super().__init__(\n",
    "            task=task,\n",
    "            eval_fn=eval_fn,\n",
    "            loss_fn=loss_fn,\n",
    "            backward_engine_model_config=backward_engine_model_config,\n",
    "            teacher_model_config=teacher_model_config,\n",
    "            text_optimizer_model_config=text_optimizer_model_config,\n",
    "        )\n",
    "\n",
    "    # tell the trainer how to call the task\n",
    "    def prepare_task(self, sample: HotPotQAData) -> Tuple[Callable[..., Any], Dict]:\n",
    "        if self.task.training:\n",
    "            return self.task.forward, {\"question\": sample.question, \"id\": sample.id}\n",
    "        else:\n",
    "            return self.task.call, {\"question\": sample.question, \"id\": sample.id}\n",
    "\n",
    "    # eval mode: get the generator output, directly engage with the eval_fn\n",
    "    def prepare_eval(self, sample: HotPotQAData, y_pred: adal.GeneratorOutput) -> float:\n",
    "        y_label = \"\"\n",
    "        if y_pred and y_pred.data and y_pred.data.answer:\n",
    "            y_label = y_pred.data.answer\n",
    "        return self.eval_fn, {\"y\": y_label, \"y_gt\": sample.answer}\n",
    "\n",
    "    # train mode: get the loss and get the data from the full_response\n",
    "    def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter):\n",
    "        # prepare gt parameter\n",
    "        y_gt = adal.Parameter(\n",
    "            name=\"y_gt\",\n",
    "            data=sample.answer,\n",
    "            eval_input=sample.answer,\n",
    "            requires_opt=False,\n",
    "        )\n",
    "\n",
    "        # pred's full_response is the output of the task pipeline which is GeneratorOutput\n",
    "        pred.eval_input = (\n",
    "            pred.full_response.data.answer\n",
    "            if pred.full_response\n",
    "            and pred.full_response.data\n",
    "            and pred.full_response.data.answer\n",
    "            else \"\"\n",
    "        )\n",
    "        return self.loss_fn, {\"kwargs\": {\"y\": pred, \"y_gt\": y_gt}}\n",
    "\n",
    "\n",
    "def train_diagnose(\n",
    "    model_client: adal.ModelClient,\n",
    "    model_kwargs: Dict,\n",
    ") -> Dict:\n",
    "\n",
    "    trainset, valset, testset = load_datasets()\n",
    "\n",
    "    adal_component = VallinaRAGAdal(\n",
    "        model_client,\n",
    "        model_kwargs,\n",
    "        backward_engine_model_config=gpt_4o_model,\n",
    "        teacher_model_config=gpt_3_model,\n",
    "        text_optimizer_model_config=gpt_3_model,\n",
    "    )\n",
    "    trainer = adal.Trainer(adaltask=adal_component)\n",
    "    trainer.diagnose(dataset=trainset, split=\"train\")\n",
    "    # trainer.diagnose(dataset=valset, split=\"val\")\n",
    "    # trainer.diagnose(dataset=testset, split=\"test\")"
   ]
  },
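  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `call` and `forward` code above both flatten the retriever output into one context string before it reaches the generator: in eval mode the mapping is applied directly, and in training mode it is registered with `add_successor_map_fn`. The sketch below reproduces only that mapping, using a stand-in dataclass so it runs without AdalFlow or a retrieval server. `FakeRetrieverOutput` and `to_context` are hypothetical names used for illustration.\n",
    "\n",
    "```python\n",
    "from dataclasses import dataclass, field\n",
    "from typing import List\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class FakeRetrieverOutput:\n",
    "    \"\"\"Stand-in for a retriever result: only the fields the mapping reads.\"\"\"\n",
    "\n",
    "    query: str\n",
    "    documents: List[str] = field(default_factory=list)\n",
    "\n",
    "\n",
    "def to_context(outputs: List[FakeRetrieverOutput]) -> str:\n",
    "    \"\"\"Join the passages of the first result, or return an empty string.\"\"\"\n",
    "    if outputs and outputs[0] and outputs[0].documents:\n",
    "        return \"\\n\\n\".join(outputs[0].documents)\n",
    "    return \"\"\n",
    "\n",
    "\n",
    "hits = [FakeRetrieverOutput(query=\"q\", documents=[\"passage one\", \"passage two\"])]\n",
    "print(repr(to_context(hits)))  # 'passage one\\n\\npassage two'\n",
    "print(repr(to_context([])))  # ''\n",
    "```\n",
    "\n",
    "Returning an empty string for an empty result keeps the prompt well-formed even when retrieval finds nothing."
   ]
  },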
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AmkbyxmuruUu"
   },
   "source": [
    "# Issues and feedback\n",
    "\n",
    "If you encounter any issues, please report them here: [GitHub Issues](https://github.com/SylphAI-Inc/LightRAG/issues).\n",
    "\n",
    "For feedback, you can use either the [GitHub discussions](https://github.com/SylphAI-Inc/LightRAG/discussions) or [Discord](https://discord.gg/ezzszrRZvT)."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
